import torch
import torchvision
import torch.nn as nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch import optim
from torch.autograd import Variable
import numpy as np
from Trace import LossScaledTrace
#from Batch_Trace import LossScaledTrace
#from ExpandedWeightsTrace import LossScaledTrace
#from Accel_Trace import LossScaledTrace
import time
import copy

# Run identifier; GD() inserts it into the checkpoint path 'Saved01/GD{folder_num}/epoch{N}.pt'.
folder_num = '11'

def GD(model, train_data, test_data, train_size=6400, bs=32, eps=50, learning_rate=0.01,
       decay_rate=1.0, starting_epoch=0, K=5, B=50, num_workers=8):
    """Train a deep copy of ``model`` with (full-batch) gradient descent and track test metrics.

    The training loader uses a batch size of 60000, so any training set with at
    most that many samples is consumed as a single batch per epoch. Every ``K``
    epochs the training loss is recorded and the model is evaluated on
    ``test_data``; every ``10 * K`` epochs the state dict is written to
    ``Saved01/GD{folder_num}/epoch{epoch}.pt``.

    Args:
        model: ``nn.Module`` to train. It is deep-copied, so the caller's
            instance is not modified. Its forward pass may return either the
            logits tensor or a tuple/list whose first element is the logits.
        train_data: Dataset yielding ``(inputs, labels)`` pairs for training.
        test_data: Dataset yielding ``(inputs, labels)`` pairs for evaluation.
        train_size: Only used, together with ``bs``, to set how often the
            per-step loss line is printed (every ``train_size / bs`` steps).
        bs: See ``train_size``. It is not the training batch size.
        eps: The number of epochs run is ``int(eps * K)``.
        learning_rate: Initial SGD learning rate.
        decay_rate: ``gamma`` of the ``ExponentialLR`` scheduler (stepped once per epoch).
        starting_epoch: Offset added to the 1-based epoch counter.
        K: Period, in epochs, of loss recording and test evaluation.
        B: Unused while the loss-scaled trace computation below is disabled.
        num_workers: Worker processes per ``DataLoader`` (0 loads in-process).

    Returns:
        Tuple ``(Epochs, ProductTraces, Frobeniuses, HessianTraces, TrainLosses,
        TestLosses, TestAccuracies)``. The first four lists are currently always
        empty because the trace computation is disabled.
    """
    import os  # local import: only needed to create the checkpoint directory

    # N_train is the size of the whole MNIST train set; using it as the batch
    # size makes every epoch a single full-batch step for subsets of MNIST.
    N_train = 60000
    test_bs = 100
    start_time = time.time()
    num_epochs = int(eps * K)

    use_cuda = torch.cuda.is_available()
    print("Working on GPU" if use_cuda else "Working on CPU")
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # Work on a private copy and make sure it lives where the batches are sent.
    cnn = copy.deepcopy(model).to(device)

    # Pinning host memory only helps host-to-GPU copies, so skip it on CPU.
    loader_kwargs = dict(shuffle=True, num_workers=num_workers, pin_memory=use_cuda)
    loaders = {
        'train': DataLoader(train_data, batch_size=N_train, **loader_kwargs),
        'test': DataLoader(test_data, batch_size=test_bs, **loader_kwargs),
    }

    ProductTraces = []
    Frobeniuses = []
    HessianTraces = []
    Epochs = []
    TrainLosses = []
    TestLosses = []
    TestAccuracies = []

    loss_func = nn.CrossEntropyLoss()
    print(loss_func)

    optimizer = optim.SGD(cnn.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, decay_rate)
    print(optimizer)

    def forward_logits(batch):
        """Run the model and return the logits for both tuple and tensor outputs."""
        out = cnn(batch)
        return out[0] if isinstance(out, (tuple, list)) else out

    def evaluate():
        """Return (mean per-batch loss, accuracy) on the test loader; leaves the model in train mode."""
        cnn.eval()
        loss_sum = 0.0
        correct = 0
        total = 0
        num_batches = 0
        with torch.no_grad():
            for images, labels in loaders['test']:
                images = images.to(device)
                labels = labels.to(device)
                outputs = forward_logits(images)
                loss_sum += loss_func(outputs, labels).item()
                # Count hits over all samples so a smaller last batch is weighted correctly.
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                total += labels.size(0)
                num_batches += 1
        cnn.train()
        return loss_sum / max(num_batches, 1), correct / max(total, 1)

    total_step = len(loaders['train'])
    # Guard against train_size < bs, which would otherwise make the modulus zero.
    log_every = max(1, int(train_size / bs))

    for ep in range(num_epochs):
        epoch = ep + (1 + starting_epoch)
        record = epoch % K == 0
        if use_cuda:
            torch.cuda.empty_cache()

        for i, (images, labels) in enumerate(loaders['train']):
            b_x = images.to(device)
            b_y = labels.to(device)

            loss = loss_func(forward_logits(b_x), b_y)

            optimizer.zero_grad()
            loss.backward()

            if record:
                # NOTE: this is the sum of the per-tensor gradient norms, not the
                # norm of the concatenated gradient vector.
                norm = 0
                for para in cnn.parameters():
                    if para.requires_grad and para.grad is not None:
                        norm += torch.norm(para.grad)
                print('The norm of the gradient for epoch {} is {}'.format(epoch, norm))

            optimizer.step()

            if record:
                TrainLosses.append(loss.item())

            if (i + 1) % log_every == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.14f}'.format(epoch, num_epochs, i + 1, total_step, loss.item()))

        # Loss-scaled trace computation is currently disabled, which is why
        # Epochs / ProductTraces / Frobeniuses / HessianTraces stay empty:
        # if (epoch + 1) % (3 * K) == 0:
        #     modelcopy = copy.deepcopy(cnn)
        #     total_params = sum(p.numel() for p in cnn.parameters())
        #     ProductTrace, Frobenius, HessianTrace = LossScaledTrace(
        #         test_model=modelcopy, train_data=train_data, d=total_params,
        #         train_size=train_size, B=B)
        #     ProductTraces.append(ProductTrace)
        #     Frobeniuses.append(Frobenius)
        #     HessianTraces.append(HessianTrace)
        #     Epochs.append(epoch + 1)

        if record:
            test_loss, accuracy = evaluate()
            TestLosses.append(test_loss)
            TestAccuracies.append(accuracy)
            print('Epoch [{}/{}], Test Loss: {:.8f}'.format(epoch, num_epochs, test_loss))
            print('GD Test Accuracy of the model on the 10000 test images: %.4f' % (accuracy))

        # Save the model
        if epoch % (10 * K) == 0:
            PATH = 'Saved01/GD{}/epoch{}.pt'.format(folder_num, epoch)
            os.makedirs(os.path.dirname(PATH), exist_ok=True)
            torch.save(cnn.state_dict(), PATH)
        scheduler.step()

    # Final evaluation after the last epoch.
    _, final_accuracy = evaluate()
    print('GD Test Accuracy of the model on the 10000 test images: %.4f' % (final_accuracy))
    print(f'GD RunTime: {time.time() - start_time:.2f}')

    return Epochs, ProductTraces, Frobeniuses, HessianTraces, TrainLosses, TestLosses, TestAccuracies